from transformers import AutoTokenizer
import os
from tqdm import tqdm
from models.clip import tokenize_clip
import pickle
# Pretrained checkpoint whose tokenizer is loaded below; the same string is
# also used as the prefix of the cache file name.
model_name_or_path = "albert-xxlarge-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)

# Number of corpus tokens per example: the tokenizer's maximum length minus the
# two extra ids that are wrapped around every block further down.
max_seq_len = tokenizer.model_max_length - 2

# Maps a split name to the path (prefix) of its raw text corpus.
Split2CorpusPath = {
	# 'val': './data/wiki-cased/en.valid.raw',
	'train': './data/wiki-cased/split.train.raw',
}
def _read_token_ids(file_path):
	"""Tokenize the text file at *file_path* into one flat list of token ids.

	The file content is split on newlines and every line is tokenized on its
	own; the resulting ids are concatenated, so line boundaries are not kept.
	"""
	print("reading")
	with open(file_path, encoding="utf-8") as f:
		text = f.read()
	print("finish_reading")

	token_ids = []
	for line in tqdm(text.split("\n")):
		token_ids.extend(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line)))
	return token_ids


def _build_examples(tokenized_text):
	"""Cut *tokenized_text* into full blocks of ``max_seq_len`` ids.

	Returns a list with one dict per block:
	  - 'token_ids': the block wrapped in the tokenizer's CLS and SEP ids,
	  - 'sent': the block decoded back into a string,
	  - 'vis_input_ids': the result of ``tokenize_clip`` on that string.

	The trailing partial block is dropped for the sake of simplicity (no
	padding); fewer than ``max_seq_len`` ids therefore yield no examples.

	Raises:
		ValueError: if the tokenizer defines no CLS or SEP token.
	"""
	# Ask the tokenizer for its special-token ids instead of hard-coding
	# BERT's 101/102, which are ordinary vocabulary entries for other models.
	cls_id = tokenizer.cls_token_id
	sep_id = tokenizer.sep_token_id
	if cls_id is None or sep_id is None:
		raise ValueError("tokenizer does not define both a CLS and a SEP token")

	examples = []
	for start in tqdm(range(0, len(tokenized_text) - max_seq_len + 1, max_seq_len)):
		block = tokenized_text[start:start + max_seq_len]
		token_ids = [cls_id] + block + [sep_id]
		assert len(token_ids) == max_seq_len + 2
		text_recon = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(block))
		# NOTE(review): context_length=77 and assign_seg=9 are passed through
		# unchanged; their meaning is defined by models.clip.tokenize_clip.
		vis_input_ids = tokenize_clip([text_recon], context_length=77, assign_seg=9)
		examples.append({
			'token_ids': token_ids,
			'sent': text_recon,
			'vis_input_ids': vis_input_ids,
		})
	return examples


# For every split, find the corpus file(s) whose name starts with the
# configured file name and cache the tokenized examples next to each of them.
for key in ["train"]:
	file_header = Split2CorpusPath[key]
	directory, filename = os.path.split(file_header)
	for shard_name in os.listdir(directory):
		# os.listdir returns bare names, so compare against the bare file
		# name rather than the full path.
		if not shard_name.startswith(filename):
			continue
		file_path = os.path.join(directory, shard_name)
		if not os.path.isfile(file_path):
			continue

		# Name the cache after the matched file so several shards sharing the
		# same prefix do not collide on a single cache file.
		cached_features_file = os.path.join(
			directory, model_name_or_path + "_cached_lm_" + str(max_seq_len) + "_" + shard_name
		)
		if os.path.exists(cached_features_file):
			print("cache exists")
			continue

		tokenized_text = _read_token_ids(file_path)
		print(len(tokenized_text))
		examples = _build_examples(tokenized_text)

		print(f"Saving features into cached file {cached_features_file}")
		with open(cached_features_file, "wb") as handle:
			pickle.dump(examples, handle, protocol=pickle.HIGHEST_PROTOCOL)